import json
import logging
import os
import random
import shutil
from argparse import ArgumentParser

import gin
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter

from advTrainer import AdvTrainer
from config import PROJECT_ROOT
from dataset.build_dataset import build_dataset
from model.create_model import create_model
from util.optimizer_scheduler import build_optimizer, build_scheduler


# Function to seed everything for reproducibility
def seed_everything(seed=0):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: Seed applied to every generator. Defaults to 0.

    Side effects:
        Writes ``PYTHONHASHSEED`` into ``os.environ``, switches cuDNN to
        deterministic mode with benchmarking disabled, and prints the seed
        to stdout.
    """
    # NOTE: the interpreter reads PYTHONHASHSEED at start-up, so this only
    # influences processes launched after this point.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(seed)

# Function for DataLoader worker initialization
def worker_init_fn(worker_id):
    """Re-seed NumPy and Python's ``random`` from torch's initial seed.

    Args:
        worker_id: Worker index supplied by the DataLoader; not used here.
    """
    # Fold torch's seed into 32 bits so it is a valid seed for np.random.seed.
    derived_seed = torch.initial_seed() & 0xFFFFFFFF
    random.seed(derived_seed)
    np.random.seed(derived_seed)

def save_random_state():
    """Snapshot the global torch (CPU) and NumPy RNG states.

    Returns:
        A ``(torch_state, numpy_state)`` tuple as produced by
        ``torch.get_rng_state()`` and ``np.random.get_state()``.
    """
    torch_snapshot = torch.get_rng_state()
    numpy_snapshot = np.random.get_state()
    return torch_snapshot, numpy_snapshot

def restore_random_state(torch_state, numpy_state):
    """Reinstate previously captured torch and NumPy global RNG states.

    Args:
        torch_state: State accepted by ``torch.set_rng_state``.
        numpy_state: State accepted by ``np.random.set_state``.
    """
    # The two generators are independent, so the order does not matter.
    np.random.set_state(numpy_state)
    torch.set_rng_state(torch_state)

def get_random_orthogonal_directions(model):
    """Draw two random direction sets shaped like the trainable parameters.

    For every parameter with ``requires_grad`` a Gaussian tensor is sampled for
    ``dir1`` and another for ``dir2``. Each ``dir2`` tensor then has its
    projection onto the matching ``dir1`` tensor removed, so the two are
    orthogonal tensor by tensor (the projection is done per parameter, not
    over the concatenation of all parameters).

    The RNG state captured by ``save_random_state`` is put back before
    returning, including when an error is raised, so sampling the directions
    does not advance those generators for the caller.

    Args:
        model: Object whose ``parameters()`` yields tensors with a
            ``requires_grad`` attribute.

    Returns:
        ``(dir1, dir2)``: two lists of tensors, one entry per trainable
        parameter, in ``model.parameters()`` order.
    """
    torch_state, numpy_state = save_random_state()
    try:
        params = [p.detach() for p in model.parameters() if p.requires_grad]
        # Sample all of dir1 before dir2 so the random stream is consumed in
        # a fixed order.
        dir1 = [torch.randn_like(p) for p in params]
        dir2 = [torch.randn_like(p) for p in params]

        # Orthogonalize dir2 with respect to dir1
        for d1, d2 in zip(dir1, dir2):
            # flatten() also works for non-contiguous layouts, on which
            # view(-1) would raise.
            flat1 = d1.flatten()
            proj = torch.dot(flat1, d2.flatten()) / torch.dot(flat1, flat1)
            d2.sub_(proj * d1)
    finally:
        restore_random_state(torch_state, numpy_state)
    return dir1, dir2

@gin.configurable
def main(parameter, seed, dataset, epoch, restore_ckpt=None, aux_dataset_path=None):
    """Build the data loaders, model and optimizer, then run adversarial training.

    Args:
        parameter: Dict of run options. Keys read here: ``cuda``,
            ``batch_size``, ``aux_batch_size``, ``num_workers``,
            ``description``, ``solution``, ``ema`` and ``gin_config``. The
            whole dict is also written to ``config.json`` in the checkpoint
            directory.
        seed: Seed handed to ``seed_everything`` after ``int()`` conversion.
        dataset: One of ``"cifar10"``, ``"cifar100"`` or ``"tiny-imagenet"``.
        epoch: Forwarded to ``build_scheduler`` and ``trainer.train``
            (presumably the total number of epochs -- confirm).
        restore_ckpt: Optional path of a checkpoint to resume from. The file
            must provide the keys ``optimizer``, ``scheduler``, ``epoch`` and
            ``model``.
        aux_dataset_path: Optional path forwarded to ``build_dataset`` to
            build an auxiliary training loader.

    Raises:
        NotImplementedError: If ``dataset`` is not one of the supported names.
    """
    # Select the visible GPUs before anything can initialise CUDA; once a CUDA
    # context exists, changing CUDA_VISIBLE_DEVICES has no effect.
    os.environ["CUDA_VISIBLE_DEVICES"] = "%s" % (parameter["cuda"])

    # NOTE(review): the directions are drawn from a throw-away model before
    # seeding, so they are not controlled by `seed` -- confirm this is intended.
    model0 = create_model(dataset_name=dataset)
    dir1, dir2 = get_random_orthogonal_directions(model0)
    print("seed:")
    print(seed)
    seed_everything(int(seed))

    torch.multiprocessing.set_sharing_strategy("file_system")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if dataset == "cifar10":
        n_class = 10
    elif dataset == "cifar100":
        n_class = 100
    elif dataset == "tiny-imagenet":
        n_class = 200
    else:
        raise NotImplementedError("No such dataset: %s" % (dataset))

    train_set = build_dataset(
        dataset_name=dataset,
        is_train=True,
    )
    val_set = build_dataset(
        dataset_name=dataset,
        is_train=False,
    )

    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=parameter["batch_size"],
        num_workers=parameter["num_workers"],
        shuffle=True,
        drop_last=True,
        worker_init_fn=worker_init_fn,
    )
    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=parameter["batch_size"],
        num_workers=parameter["num_workers"],
        shuffle=False,
        worker_init_fn=worker_init_fn,
    )
    if aux_dataset_path is not None:
        aux_set = build_dataset(
            dataset_name=dataset,
            is_train=True,
            aux_dataset_path=aux_dataset_path,
        )
        aux_loader = torch.utils.data.DataLoader(
            aux_set,
            batch_size=parameter["aux_batch_size"],
            num_workers=parameter["num_workers"],
            shuffle=True,
            drop_last=True,
            # DataLoader rejects persistent_workers=True when there are no
            # worker processes.
            persistent_workers=parameter["num_workers"] > 0,
            worker_init_fn=worker_init_fn,
        )
    else:
        aux_loader = None

    model = create_model(dataset_name=dataset)
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)

    optimizer = build_optimizer(model, use_extra_data=aux_loader is not None)
    scheduler = build_scheduler(optimizer, epoch, len(train_loader))

    start_epoch = 0
    if restore_ckpt is not None:
        print("restoring")
        # NOTE: torch.load unpickles the file; only resume from checkpoints
        # that come from a trusted source.
        ckpt = torch.load(restore_ckpt)
        # NOTE(review): optimizer/scheduler are replaced by whatever the
        # checkpoint stores under these keys -- presumably full objects rather
        # than state dicts; confirm against the code that writes checkpoints.
        optimizer = ckpt["optimizer"]
        scheduler = ckpt["scheduler"]
        start_epoch = ckpt["epoch"]
        model.load_state_dict(ckpt["model"])

    EXP_DIR = os.path.join(PROJECT_ROOT, dataset + "_experiment")
    print(EXP_DIR)
    os.makedirs(EXP_DIR, exist_ok=True)

    # The checkpoint directory drops the first two characters of the
    # description (presumably a leading "./" -- confirm against callers).
    withoutdot = parameter["description"][2:]
    ckpt_path = os.path.join(EXP_DIR, withoutdot)
    print("ckpt")
    print(ckpt_path)
    runs_path = os.path.join(EXP_DIR, "runs%s" % (parameter["description"]))
    hparam = {
        "description": parameter["description"],
        "batch_size": parameter["batch_size"],
        "aux_batch_size": parameter["aux_batch_size"],
        "dataset": dataset,
        "n_class": n_class,
        "solution": parameter["solution"],
    }

    logging.info("Using device %s", device)
    logging.info("Dataset: %s", dataset)
    logging.info("Total number of params: %d", n_parameters)
    logging.info("create model class %d", n_class)

    if not os.path.exists(ckpt_path):
        print(ckpt_path)
    os.makedirs(ckpt_path, exist_ok=True)
    # Keep a copy of the effective CLI options and gin config next to the
    # checkpoints.
    with open(os.path.join(ckpt_path, "config.json"), "w") as config_file:
        json.dump(parameter, config_file, indent=4, sort_keys=True)
    shutil.copy(parameter["gin_config"], os.path.join(ckpt_path, "config.gin"))
    writer = SummaryWriter(runs_path)
    logging.info("Start Training")

    trainer = AdvTrainer(
        device,
        hparam,
        use_ema=parameter["ema"],
        aux_loader=aux_loader,
    )
    trainer.train(
        model,
        train_loader,
        val_loader,
        optimizer,
        scheduler,
        ckpt_path,
        dir1,
        dir2,
        writer,
        epoch,
        start_epoch=start_epoch,
        runs_path=runs_path,
    )


def _parse_argument():
    """Parse the training script's command-line options.

    Returns:
        The ``argparse.Namespace`` built from the process arguments.
    """
    # (flag, add_argument keyword arguments), in the order they are registered.
    option_specs = (
        ("--gin_config", {"type": str, "required": True}),
        (
            "--restore_ckpt",
            {
                "type": str,
                "default": None,
                "help": "Path to restore from the checkpoint trained half way",
            },
        ),
        ("--description", {"type": str, "required": True}),
        ("--cuda", {"type": str, "default": "4"}),
        ("--batch_size", {"type": int, "default": 128}),
        ("--aux_batch_size", {"type": int, "default": 0}),
        ("--num_workers", {"type": int, "default": 16}),
        ("--ema", {"action": "store_true"}),
        (
            "--seed",
            {
                "type": int,
                "default": 42,
                "help": "Seed for random number generators",
            },
        ),
        (
            "--solution",
            {
                "action": "store_true",
                "help": "Pass the solution flag to AdvTrainer",
            },
        ),
    )
    cli = ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


if __name__ == "__main__":
    # Log verbosity comes from the LOGLEVEL environment variable (INFO if unset).
    loglevel = os.environ.get("LOGLEVEL", "INFO").upper()
    logging.basicConfig(
        level=loglevel,
        datefmt="%Y-%m-%d %H:%M:%S",
        format="%(asctime)s | %(levelname)s | %(message)s",
    )

    args = _parse_argument()
    # vars() returns the namespace's own attribute dict, so it already holds
    # every CLI option; the gin config path is resolved against PROJECT_ROOT.
    parameter = vars(args)
    parameter.update(
        description=args.description,
        cuda=args.cuda,
        gin_config=os.path.join(PROJECT_ROOT, parameter["gin_config"]),
        seed=args.seed,
    )

    # Bind the gin-configurable arguments of main() before calling it.
    gin.parse_config_files_and_bindings([parameter["gin_config"]], None)
    main(parameter, restore_ckpt=args.restore_ckpt, seed=parameter["seed"])
